balance Quickstart (CBPS): Analyzing and adjusting the bias on a simulated toy dataset

'balance' is a Python package maintained and released by the Core Data Science Tel-Aviv team at Meta. 'balance' performs and evaluates bias reduction through weighting, for a broad set of experimental and observational use cases.

Although balance is written in Python, you don't need a deep understanding of Python to use it. In fact, you can simply take this notebook, load your own data, change a few variables, and re-run it to produce your own weights!

This quickstart demonstrates re-weighting a specific simulated dataset. If you have a different use case or want more comprehensive documentation, check out the full balance tutorial.

Analysis

There are four main steps to analysis with balance:

  • load data
  • check diagnostics before adjustment
  • perform adjustment + check diagnostics
  • output results

Let's dive right in!
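Before going into the details, here is a condensed (and slightly simplified) sketch of the four steps as code. Every call below is demonstrated and explained in the rest of this notebook:

from balance import Sample, load_data

# 1. load data (here: the simulated toy dataset)
target_df, sample_df = load_data()
sample = Sample.from_frame(sample_df, outcome_columns=["happiness"])
target = Sample.from_frame(target_df, outcome_columns=["happiness"])
sample_with_target = sample.set_target(target)

# 2. check diagnostics before adjustment
print(sample_with_target.covars().asmd())

# 3. perform adjustment + check diagnostics
adjusted = sample_with_target.adjust(method="cbps")
print(adjusted.summary())

# 4. output results
print(adjusted.to_csv()[0:200])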

Example dataset

The following is a toy simulated dataset.

In [1]:
from balance import load_data
INFO (2024-11-25 18:55:26,428) [__init__/<module> (line 54)]: Using balance version 0.9.1
In [2]:
target_df, sample_df = load_data()

print("target_df: \n", target_df.head())
print("sample_df: \n", sample_df.head())
target_df: 
        id gender age_group     income  happiness
0  100000   Male       45+  10.183951  61.706333
1  100001   Male       45+   6.036858  79.123670
2  100002   Male     35-44   5.226629  44.206949
3  100003    NaN       45+   5.752147  83.985716
4  100004    NaN     25-34   4.837484  49.339713
sample_df: 
   id  gender age_group     income  happiness
0  0    Male     25-34   6.428659  26.043029
1  1  Female     18-24   9.940280  66.885485
2  2    Male     18-24   2.673623  37.091922
3  3     NaN     18-24  10.550308  49.394050
4  4     NaN     18-24   2.689994  72.304208
In [3]:
target_df.head().round(2).to_dict()
# sample_df.shape
Out[3]:
{'id': {0: '100000', 1: '100001', 2: '100002', 3: '100003', 4: '100004'},
 'gender': {0: 'Male', 1: 'Male', 2: 'Male', 3: nan, 4: nan},
 'age_group': {0: '45+', 1: '45+', 2: '35-44', 3: '45+', 4: '25-34'},
 'income': {0: 10.18, 1: 6.04, 2: 5.23, 3: 5.75, 4: 4.84},
 'happiness': {0: 61.71, 1: 79.12, 2: 44.21, 3: 83.99, 4: 49.34}}

In practice, one can use a pandas loading function (such as read_csv()) to import data into the DataFrame objects sample_df and target_df.
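For example, loading your own data might look like the hypothetical sketch below (the file names are placeholders for your own survey and population files):

import pandas as pd

# Hypothetical file names - replace with your own data.
sample_df = pd.read_csv("my_sample.csv")    # the (biased) sample to be adjusted
target_df = pd.read_csv("my_target.csv")    # the target population to adjust towards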

Load data into a Sample object

The first thing to do is to import the Sample class from balance. All of the data we're going to be working with, sample or population, will be stored in objects of the Sample class.

In [4]:
from balance import Sample

Using the Sample class, we can create both the "sample" we want to adjust and the "target" we want to adjust it towards.

We turn the two pandas DataFrame objects we created (or loaded) into balance.Sample objects by using the .from_frame() method:

In [5]:
sample = Sample.from_frame(sample_df, outcome_columns=["happiness"])
target = Sample.from_frame(target_df, outcome_columns=["happiness"])
WARNING (2024-11-25 18:55:26,612) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-25 18:55:26,618) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
WARNING (2024-11-25 18:55:26,624) [util/guess_id_column (line 114)]: Guessed id column name id for the data
WARNING (2024-11-25 18:55:26,635) [sample_class/from_frame (line 261)]: No weights passed. Adding a 'weight' column and setting all values to 1
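The warnings above are informational: balance guessed that the id column is id, and added a weight column with all values set to 1. If your data already contains explicit id or weight columns, these can presumably be passed to .from_frame() directly; the keyword names below are an assumption and should be checked against the balance documentation:

# Assumed keyword names - verify against the balance docs.
sample = Sample.from_frame(
    sample_df,
    id_column="id",                    # silences the "Guessed id column" warning
    outcome_columns=["happiness"],
    # weight_column="my_weights",      # if the data already carries weights
)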

If we use the .df property, we can see the DataFrame stored in sample. Note the new weight column (all values set to 1) that was added when the DataFrame was imported into a balance.Sample object.

In [6]:
sample.df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 6 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   id         1000 non-null   object 
 1   gender     912 non-null    object 
 2   age_group  1000 non-null   object 
 3   income     1000 non-null   float64
 4   happiness  1000 non-null   float64
 5   weight     1000 non-null   int64  
dtypes: float64(2), int64(1), object(3)
memory usage: 47.0+ KB
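As a quick sanity check, we can confirm that the auto-added weight column is indeed all 1s:

# All weights were initialized to 1 when the DataFrame was imported.
print(sample.df["weight"].value_counts())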

We can get a quick text overview of each Sample object by just calling it.

Let's take a look at what this produces:

In [7]:
sample
Out[7]:
(balance.sample_class.Sample)

        balance Sample object
        1000 observations x 3 variables: gender,age_group,income
        id_column: id, weight_column: weight,
        outcome_columns: happiness
        
In [8]:
target
Out[8]:
(balance.sample_class.Sample)

        balance Sample object
        10000 observations x 3 variables: gender,age_group,income
        id_column: id, weight_column: weight,
        outcome_columns: happiness
        

Next, we combine the sample object with the target object. This is what will allow us to adjust the sample to the target.

In [9]:
sample_with_target = sample.set_target(target)

Looking at sample_with_target now, we can see it has the target attached:

In [10]:
sample_with_target
Out[10]:
(balance.sample_class.Sample)

        balance Sample object with target set
        1000 observations x 3 variables: gender,age_group,income
        id_column: id, weight_column: weight,
        outcome_columns: happiness
        
            target:
                 
	        balance Sample object
	        10000 observations x 3 variables: gender,age_group,income
	        id_column: id, weight_column: weight,
	        outcome_columns: happiness
	        
            3 common variables: gender,age_group,income
            

Pre-Adjustment Diagnostics

We can use .covars() and then follow up with .mean() and .plot() (bar plots and QQ-plots) to get some basic diagnostics of what we have.

We can see how:

  • The proportion of missing values in gender is similar in the sample and the target.
  • We have younger people in the sample as compared to the target.
  • We have more males than females in the sample, as compared to a roughly 50-50 split in the (non-NA) target.
  • Income is more right-skewed in the target than in the sample.
In [11]:
print(sample_with_target.covars().mean().T)
source                     self     target
_is_na_gender[T.True]  0.088000   0.089800
age_group[T.25-34]     0.300000   0.297400
age_group[T.35-44]     0.156000   0.299200
age_group[T.45+]       0.053000   0.206300
gender[Female]         0.268000   0.455100
gender[Male]           0.644000   0.455100
gender[_NA]            0.088000   0.089800
income                 6.297302  12.737608
In [12]:
print(sample_with_target.covars().asmd().T)
source                  self
age_group[T.25-34]  0.005688
age_group[T.35-44]  0.312711
age_group[T.45+]    0.378828
gender[Female]      0.375699
gender[Male]        0.379314
gender[_NA]         0.006296
income              0.494217
mean(asmd)          0.326799
In [13]:
print(sample_with_target.covars().asmd(aggregate_by_main_covar = True).T)
source          self
age_group   0.232409
gender      0.253769
income      0.494217
mean(asmd)  0.326799
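For a single numeric covariate such as income, the ASMD is the absolute standardized mean difference between sample and target. Assuming the standardization uses the target's standard deviation (an assumption here; check the balance documentation for the exact definition), a minimal re-derivation from the raw DataFrames looks like:

# Hedged re-derivation of the income ASMD from the raw data.
# Assumption: balance standardizes by the target's standard deviation.
diff = abs(sample_df["income"].mean() - target_df["income"].mean())
asmd_income = diff / target_df["income"].std()
print(round(asmd_income, 3))  # expected to be close to the ~0.494 reported above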
In [14]:
sample_with_target.covars().plot()

Adjusting Sample to Population (ipw and cbps)

Next, we adjust the sample to the target. The default method is 'ipw' (inverse probability/propensity weighting, based on a logistic regression model with LASSO regularization).

In [15]:
# Using ipw to fit survey weights
adjusted_ipw = sample_with_target.adjust()
INFO (2024-11-25 18:55:27,356) [ipw/ipw (line 424)]: Starting ipw function
INFO (2024-11-25 18:55:27,359) [adjustment/apply_transformations (line 306)]: Adding the variables: []
INFO (2024-11-25 18:55:27,359) [adjustment/apply_transformations (line 307)]: Transforming the variables: ['gender', 'age_group', 'income']
INFO (2024-11-25 18:55:27,371) [adjustment/apply_transformations (line 347)]: Final variables in output: ['gender', 'age_group', 'income']
INFO (2024-11-25 18:55:27,381) [ipw/ipw (line 458)]: Building model matrix
INFO (2024-11-25 18:55:27,535) [ipw/ipw (line 482)]: The formula used to build the model matrix: ['income + gender + age_group + _is_na_gender']
INFO (2024-11-25 18:55:27,536) [ipw/ipw (line 485)]: The number of columns in the model matrix: 16
INFO (2024-11-25 18:55:27,536) [ipw/ipw (line 486)]: The number of rows in the model matrix: 11000
INFO (2024-11-25 18:55:27,543) [ipw/ipw (line 517)]: Fitting logistic model
INFO (2024-11-25 18:55:28,775) [ipw/ipw (line 558)]: max_de: None
INFO (2024-11-25 18:55:28,779) [ipw/ipw (line 588)]: Chosen lambda for cv: [0.0131066]
INFO (2024-11-25 18:55:28,781) [ipw/ipw (line 596)]: Proportion null deviance explained [0.17168419]
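The max_de: None line in the log refers to an optional cap on the design effect of the fitted weights. Assuming adjust() forwards such a max_de keyword to the ipw method (an assumption based on the log above; check the balance documentation before relying on it), capping the design effect might look like:

# Hypothetical: cap the design effect of the ipw weights (keyword assumed from the log above).
adjusted_ipw_capped = sample_with_target.adjust(method="ipw", max_de=1.5)
print(adjusted_ipw_capped.summary())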
In [16]:
adjusted_cbps = sample_with_target.adjust(method = "cbps")
INFO (2024-11-25 18:55:28,791) [cbps/cbps (line 411)]: Starting cbps function
INFO (2024-11-25 18:55:28,794) [adjustment/apply_transformations (line 306)]: Adding the variables: []
INFO (2024-11-25 18:55:28,795) [adjustment/apply_transformations (line 307)]: Transforming the variables: ['gender', 'age_group', 'income']
INFO (2024-11-25 18:55:28,806) [adjustment/apply_transformations (line 347)]: Final variables in output: ['gender', 'age_group', 'income']
INFO (2024-11-25 18:55:28,921) [cbps/cbps (line 462)]: The formula used to build the model matrix: ['income + gender + age_group + _is_na_gender']
INFO (2024-11-25 18:55:28,923) [cbps/cbps (line 474)]: The number of columns in the model matrix: 16
INFO (2024-11-25 18:55:28,923) [cbps/cbps (line 475)]: The number of rows in the model matrix: 11000
INFO (2024-11-25 18:55:28,938) [cbps/cbps (line 537)]: Finding initial estimator for GMM optimization
INFO (2024-11-25 18:55:29,064) [cbps/cbps (line 564)]: Finding initial estimator for GMM optimization that minimizes the balance loss
WARNING (2024-11-25 18:55:29,451) [cbps/cbps (line 581)]: Convergence of bal_loss function has failed due to 'Maximum number of function evaluations has been exceeded.'
INFO (2024-11-25 18:55:29,452) [cbps/cbps (line 599)]: Running GMM optimization
WARNING (2024-11-25 18:55:29,986) [cbps/cbps (line 614)]: Convergence of gmm_loss function with gmm_init start point has failed due to 'Maximum number of function evaluations has been exceeded.'
WARNING (2024-11-25 18:55:30,515) [cbps/cbps (line 632)]: Convergence of gmm_loss function with beta_balance start point has failed due to 'Maximum number of function evaluations has been exceeded.'
INFO (2024-11-25 18:55:30,521) [cbps/cbps (line 730)]: Done cbps function
In [17]:
print(adjusted_ipw)
        Adjusted balance Sample object with target set using ipw
        1000 observations x 3 variables: gender,age_group,income
        id_column: id, weight_column: weight,
        outcome_columns: happiness
        
            target:
                 
	        balance Sample object
	        10000 observations x 3 variables: gender,age_group,income
	        id_column: id, weight_column: weight,
	        outcome_columns: happiness
	        
            3 common variables: gender,age_group,income
            
In [18]:
# the cbps-adjusted object prints in the same format as the ipw one
print(adjusted_cbps)
        Adjusted balance Sample object with target set using cbps
        1000 observations x 3 variables: gender,age_group,income
        id_column: id, weight_column: weight,
        outcome_columns: happiness
        
            target:
                 
	        balance Sample object
	        10000 observations x 3 variables: gender,age_group,income
	        id_column: id, weight_column: weight,
	        outcome_columns: happiness
	        
            3 common variables: gender,age_group,income
            

Evaluation of the Results (CBPS vs IPW)

We can get a basic summary of the results:

In [19]:
print(adjusted_ipw.summary())
Covar ASMD reduction: 59.7%, design effect: 1.897
Covar ASMD (7 variables): 0.327 -> 0.132
Model performance: Model proportion deviance explained: 0.172
In [20]:
print(adjusted_cbps.summary())
Covar ASMD reduction: 77.6%, design effect: 2.782
Covar ASMD (7 variables): 0.327 -> 0.073
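The "Covar ASMD reduction" figure appears to be the relative drop in the mean ASMD; a quick cross-check from the two summaries above (small differences are due to rounding):

# Rough cross-check of the reported ASMD reduction percentages.
print(f"ipw:  {(0.327 - 0.132) / 0.327:.1%}")   # reported: 59.7%
print(f"cbps: {(0.327 - 0.073) / 0.327:.1%}")   # reported: 77.6%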

We can see that CBPS did a better job in terms of ASMD reduction, i.e., a larger improvement in the average ASMD. We can look at the detailed list of ASMD values per variable using the following call:

In [21]:
print("ipw:")
print(adjusted_ipw.covars().asmd().T)
print("\ncbps:")
print(adjusted_cbps.covars().asmd().T)
ipw:
source                  self  unadjusted  unadjusted - self
age_group[T.25-34]  0.001085    0.005688           0.004602
age_group[T.35-44]  0.037455    0.312711           0.275256
age_group[T.45+]    0.129304    0.378828           0.249525
gender[Female]      0.133970    0.375699           0.241730
gender[Male]        0.109697    0.379314           0.269617
gender[_NA]         0.042278    0.006296          -0.035983
income              0.243762    0.494217           0.250455
mean(asmd)          0.131675    0.326799           0.195124

cbps:
source                  self  unadjusted  unadjusted - self
age_group[T.25-34]  0.051879    0.005688          -0.046192
age_group[T.35-44]  0.031114    0.312711           0.281597
age_group[T.45+]    0.105655    0.378828           0.273173
gender[Female]      0.034514    0.375699           0.341185
gender[Male]        0.058580    0.379314           0.320733
gender[_NA]         0.041919    0.006296          -0.035624
income              0.111468    0.494217           0.382749
mean(asmd)          0.073118    0.326799           0.253680

It's easier to learn about the biases by just running .covars().plot() on our adjusted object.

In [22]:
adjusted_ipw.covars().plot(library = "seaborn", dist_type = "kde")
In [23]:
adjusted_cbps.covars().plot(library = "seaborn", dist_type = "kde")

Here we used a different plot style: the seaborn library with the "kde" dist_type (instead of the default bar plots and QQ-plots).

Understanding the weights

We can get the design effect of the resulting weights using:

In [24]:
print("ipw:")
print(adjusted_ipw.weights().design_effect())
print("\ncbps:")
print(adjusted_cbps.weights().design_effect())
ipw:
1.8973847221820574

cbps:
2.7816765614638572
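The design effect reported here is presumably Kish's design effect, n * sum(w^2) / (sum(w))^2 (an assumption; check the balance documentation for the exact definition). A minimal sketch computing it directly from the adjusted weights:

import numpy as np

def kish_design_effect(weights):
    # Kish design effect: n * sum(w^2) / (sum(w))^2
    w = np.asarray(weights, dtype=float)
    return len(w) * np.sum(w ** 2) / np.sum(w) ** 2

# Cross-check against the value reported above, using the weight column of the adjusted object.
print(kish_design_effect(adjusted_ipw.df["weight"]))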

Outcome analysis

In [25]:
print(adjusted_ipw.outcomes().summary())
adjusted_ipw.outcomes().plot()
1 outcomes: ['happiness']
Mean outcomes (with 95% confidence intervals):
source       self  target  unadjusted           self_ci         target_ci     unadjusted_ci
happiness  53.389  56.278      48.559  (52.183, 54.595)  (55.961, 56.595)  (47.669, 49.449)

Response rates (relative to number of respondents in sample):
   happiness
n     1000.0
%      100.0
Response rates (relative to notnull rows in the target):
    happiness
n     1000.0
%       10.0
Response rates (in the target):
    happiness
n    10000.0
%      100.0

The estimated mean happiness according to our sample is around 48.6 without any adjustment, and around 53.4 (IPW) or 54.4 (CBPS) after adjustment. The following shows the distribution of happiness:

In [26]:
print(adjusted_cbps.outcomes().summary())
adjusted_cbps.outcomes().plot()
1 outcomes: ['happiness']
Mean outcomes (with 95% confidence intervals):
source       self  target  unadjusted          self_ci         target_ci     unadjusted_ci
happiness  54.389  56.278      48.559  (53.02, 55.757)  (55.961, 56.595)  (47.669, 49.449)

Response rates (relative to number of respondents in sample):
   happiness
n     1000.0
%      100.0
Response rates (relative to notnull rows in the target):
    happiness
n     1000.0
%       10.0
Response rates (in the target):
    happiness
n    10000.0
%      100.0
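Assuming the adjusted mean is simply a weighted average of the outcome under the new weights (an assumption here), we can roughly reproduce it from the adjusted object's DataFrame:

# Hedged cross-check: reproduce the adjusted mean happiness as a weighted average.
adj_df = adjusted_cbps.df
weighted_mean = (adj_df["happiness"] * adj_df["weight"]).sum() / adj_df["weight"].sum()
print(round(weighted_mean, 3))  # expected to be close to the ~54.4 reported above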

As we can see, CBPS has a larger design effect, but it also removes more of the ASMD and moves the outcome estimate further. So there are pros and cons to each of the two methods.

Downloading data

Finally, we can prepare the data to be downloaded for future analyses.

In [27]:
adjusted_cbps.to_download()
Out[27]:
Click here to download: /tmp/tmp_balance_out_28f0f7ee-ba24-43fb-9ba8-b4a7f7a78f18.csv
In [28]:
# We can prepare the data to be exported as CSV - showing the first 500 characters for simplicity:
adjusted_cbps.to_csv()[0:500]
Out[28]:
'id,gender,age_group,income,happiness,weight\n0,Male,25-34,6.428659499046228,26.043028759747298,5.093908021189902\n1,Female,18-24,9.940280228116047,66.88548460632677,0.4137864502389744\n2,Male,18-24,2.6736231547518043,37.091921916683006,2.255219921002779\n3,,18-24,10.550307519418066,49.39405003271002,4.974470708135918\n4,,18-24,2.689993854299385,72.30420755038209,3.343868455923355\n5,,35-44,5.995497722733131,57.28281646341816,17.083435577163435\n6,,18-24,12.63469573898972,31.663293445944596,5.5913639935'
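Since .to_csv() returns the CSV content as a string (as shown above), one simple way to save it locally is to write that string to a file:

from pathlib import Path

# Write the full CSV string returned by to_csv() to a local file (the file name is arbitrary).
Path("adjusted_cbps_weights.csv").write_text(adjusted_cbps.to_csv())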
In [29]:
# Sessions info
import session_info
session_info.show(html=False, dependencies=True)
-----
balance             0.9.1
pandas              1.4.3
session_info        1.0.0
-----
PIL                         11.0.0
anyio                       NA
apport_python_hook          NA
argcomplete                 NA
arrow                       1.3.0
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
babel                       2.16.0
beta_ufunc                  NA
binom_ufunc                 NA
certifi                     2020.06.20
chardet                     4.0.0
charset_normalizer          3.4.0
colorama                    0.4.4
comm                        0.2.2
coxnet                      NA
cvcompute                   NA
cvelnet                     NA
cvfishnet                   NA
cvglmnet                    NA
cvglmnetCoef                NA
cvglmnetPredict             NA
cvlognet                    NA
cvmrelnet                   NA
cvmultnet                   NA
cycler                      0.12.1
cython_runtime              NA
dateutil                    2.9.0.post0
debugpy                     1.8.9
decorator                   5.1.1
defusedxml                  0.7.1
elnet                       NA
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fishnet                     NA
fqdn                        NA
gi                          3.42.1
gio                         NA
glib                        NA
glmnet                      NA
glmnetCoef                  NA
glmnetControl               NA
glmnetPredict               NA
glmnetSet                   NA
glmnet_python               NA
gobject                     NA
gtk                         NA
hypergeom_ufunc             NA
idna                        3.3
ipfn                        NA
ipykernel                   6.29.5
isoduration                 NA
jedi                        0.19.2
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.28
jsonpointer                 2.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
loadGlmLib                  NA
lognet                      NA
markupsafe                  2.0.1
matplotlib                  3.9.2
matplotlib_inline           0.1.7
mpl_toolkits                NA
mrelnet                     NA
nbformat                    5.10.4
nbinom_ufunc                NA
ncf_ufunc                   NA
numpy                       1.24.4
overrides                   NA
packaging                   24.2
parso                       0.8.4
patsy                       1.0.1
platformdirs                4.3.6
plotly                      5.24.1
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.1.0
pure_eval                   0.2.3
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pygments                    2.18.0
pyparsing                   2.4.7
pythonjsonlogger            NA
pytz                        2022.1
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rpds                        NA
scipy                       1.9.1
seaborn                     0.13.0
send2trash                  NA
sitecustomize               NA
six                         1.16.0
sklearn                     1.5.2
sniffio                     1.3.1
sphinxcontrib               NA
stack_data                  0.6.3
statsmodels                 0.14.4
tenacity                    NA
threadpoolctl               3.5.0
tornado                     6.4.2
traitlets                   5.14.3
typing_extensions           NA
uri_template                NA
urllib3                     1.26.5
wcwidth                     0.2.13
webcolors                   NA
websocket                   1.8.0
wtmean                      NA
yaml                        5.4.1
zmq                         26.2.0
zoneinfo                    NA
zope                        NA
-----
IPython             8.29.0
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.6
notebook            7.2.2
-----
Python 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
Linux-6.5.0-1025-azure-x86_64-with-glibc2.35
-----
Session information updated at 2024-11-25 18:55